
import json
import queue
import time
import traceback
from collections import deque
from threading import Lock, Thread

import requests

from quant import config
from quant.utils import Timer2, logging, catch_exception, EventEngine, Iota


class LagInfoEngine(EventEngine):
    """Event channel for request-latency diagnostics.

    Publishes three events through the inherited ``put``/``subscribe`` API:
    ``'request_step'`` with (req_id, model, req, step, worker), ``'Wait'``
    with (queue_wait) and ``'Lag'`` with (model, lag). Requesting publishes
    the timings as ``time.perf_counter()`` differences, i.e. seconds.
    """

    def put_request_step(self, req_id, model, req, step, worker=None):
        """Publish that request *req_id* reached *step*.

        *worker* is an optional identifier of the session handling the request.
        """
        self.put('request_step', req_id, model, req, step, worker)

    def put_queue_wait(self, queue_wait):
        """Publish how long a request waited in the queue before being picked up."""
        self.put('Wait', queue_wait)

    def put_lag(self, model, lag):
        """Publish how long sending *model*'s request took."""
        self.put('Lag', model, lag)

    def subscribe_request_step(self, handler):
        """Register ``handler(req_id, model, req, step, worker)`` for step events."""
        self.subscribe('request_step', handler)

    def subscribe_queue_wait(self, handler):
        """Register ``handler(queue_wait)`` for queue-wait events."""
        self.subscribe('Wait', handler)

    def subscribe_lag(self, handler):
        """Register ``handler(model, lag)`` for lag events."""
        self.subscribe('Lag', handler)

    def show_all(self):
        """Subscribe handlers that print every queue wait and lag in milliseconds."""
        def on_queue_wait(wait):
            print('wait:', format_time(wait))

        def on_lag(model, lag):
            print('lag:', format_time(lag))

        def format_time(t):
            return '{:.2f}ms'.format(t * 1000)

        self.subscribe_queue_wait(on_queue_wait)
        self.subscribe_lag(on_lag)

    def show_all_step(self):
        """Print, for every non-initial step, the ms elapsed since the request was initiated."""
        from quant.utils import LimitDict

        # NOTE(review): LimitDict(100) presumably bounds the number of tracked
        # requests to 100 — confirm against quant.utils.
        limit_dict = LimitDict(100)

        def show_step(record, step):
            lag = time.perf_counter() - record['initiate']
            lag = '{:.2f}'.format(lag * 1000)
            print('{}lag: {}'.format(step, lag))

        def on_step(req_id, model, req, step, worker=None):
            # put_request_step publishes five values (..., step, worker), so the
            # handler must accept `worker` as well; it is not used here.
            if step == 'initiate':
                limit_dict[req_id] = {step: time.perf_counter()}
                return
            record = limit_dict.get(req_id)
            if record is not None:
                show_step(record, step)

        self.subscribe_request_step(on_step)


class IdCreator:
    """Generate short string ids of the form ``<prefix><n>``.

    Every new instance takes the next character (starting at ``chr(97)``,
    i.e. 'a') as its prefix and numbers its ids 1, 2, 3, ... under it.
    """

    # Code point the next instance will use as its prefix.
    _ord_count = 97

    def __init__(self):
        owner = type(self)
        self._head = chr(owner._ord_count)
        owner._ord_count = owner._ord_count + 1
        self._count = 0

    def create_id(self):
        """Return the next id for this instance, e.g. 'a1', then 'a2'."""
        self._count = self._count + 1
        return self._head + str(self._count)


class SessionWrap:
    """A ``requests.Session`` tagged with an id, owned by a SessionPool.

    Used as a context manager it yields itself. On exit it always puts itself
    back into the owning pool. An exception raised inside the ``with`` body
    has its traceback printed and is then suppressed.
    """

    def __init__(self, pool):
        # Owning pool; its id_creator also supplies this wrapper's id.
        self._pool = pool
        self.id = pool.id_creator.create_id()
        self.session = requests.Session()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not (exc_type is None and exc_tb is None):
            traceback.print_exc()
        self._pool.put(self)
        # A truthy return value makes the with-statement swallow the exception.
        return True


class SessionPool:
    """Lock-protected FIFO pool of SessionWrap objects, created on demand."""

    def __init__(self, id_creator):
        # SessionWrap asks this for the id of each newly created session.
        self.id_creator = id_creator
        self._lock = Lock()
        # Idle sessions, handed out oldest-first.
        self._sessions = deque()

    def get(self):
        """Take the oldest idle session, or create (and log) a new one if none is idle."""
        with self._lock:
            if self._sessions:
                return self._sessions.popleft()
            fresh = SessionWrap(self)
            logging.info('create session({})'.format(fresh.id))
            return fresh

    def put(self, session):
        """Return *session* to the pool so a later ``get`` can reuse it."""
        with self._lock:
            self._sessions.append(session)


class Requesting:
    """Send ``requests`` requests from a pool of daemon worker threads.

    :meth:`request_async` queues a request. A worker thread borrows a pooled
    ``requests.Session``, sends it and passes the response to the optional
    callback. Per-request timing is published on the class-level
    ``lag_info_engine``, which is shared by every instance.
    """

    lag_info_engine = LagInfoEngine()

    def __init__(self, workers=30):
        """Start *workers* daemon worker threads and a periodic liveness check."""
        self._queue = queue.Queue()
        self._pool = SessionPool(IdCreator())
        self._all_thread = self._start_thread(workers)
        # NOTE(review): check_interval's unit is defined by Timer2 — confirm.
        Timer2(self._check_thread_alive, check_interval).start()

    def request_async(self, model, req, callback=None):
        """Queue *req* for a worker thread and return its request id.

        When the request finishes, ``callback(model, response)`` is called on
        the worker thread (if a callback was given). If sending fails,
        *response* is a synthetic ``requests.Response`` with ``status_code == -1``.
        """
        req_id = req_id_iota.next()
        self.lag_info_engine.put_request_step(req_id, model, req, 'initiate')
        self._queue.put((req_id, model, req, callback, time.perf_counter()))
        return req_id

    def request(self, req):
        """Send *req* synchronously on a one-off session.

        Returns the response, or None if preparing or sending raised (the
        error is logged).
        """
        # The with-block closes the session and its connection pool instead of
        # leaking one per call. The response body has already been read
        # (send() defaults to stream=False), so closing is safe.
        with requests.Session() as s:
            try:
                return s.send(req.prepare())
            except Exception as err:
                logging.error(str(err))
                return None

    def join(self):
        """Block until every queued request has been processed."""
        return self._queue.join()

    def _start_thread(self, count):
        """Start *count* daemon threads running :meth:`_run` and return them."""
        all_thread = []
        for _ in range(count):
            t = Thread(target=self._run, daemon=True)
            t.start()
            all_thread.append(t)
        return all_thread

    def _run(self):
        """Worker loop: take queued requests forever and process them one by one."""
        while True:
            req_id, model, req, callback, quest_time = self._queue.get()
            try:
                self._process(req_id, model, req, callback, quest_time)
            except Exception:
                # Keep the worker alive even if something outside the request
                # itself fails (e.g. creating a pooled session).
                catch_exception()
            finally:
                # Exactly one task_done() per get(), whatever happened, so
                # join() can never hang on an item that failed part-way.
                self._queue.task_done()

    def _process(self, req_id, model, req, callback, quest_time):
        """Send one queued request on a pooled session and deliver its response."""
        with self._pool.get() as wrap:
            session = wrap.session
            count = wrap.id

            self.lag_info_engine.put_request_step(req_id, model, req, 'queue got', count)
            now = time.perf_counter()
            wait = now - quest_time

            self.lag_info_engine.put_queue_wait(wait)
            if wait > too_long_in_queue:
                logging.error('request blocked in queue for {:.3f}ms'.format(wait * 1000))

            try:
                prep = session.prepare_request(req)
                self.lag_info_engine.put_request_step(req_id, model, req, 'request prepared', count)
                resp = session.send(prep, timeout=config.rest_timeout)
                self.lag_info_engine.put_request_step(req_id, model, req, 'respond received', count)
            except Exception as err:
                self.lag_info_engine.put_request_step(req_id, model, req, 'request fail', count)
                resp = self._build_fail_response(err)

            lag = time.perf_counter() - now
            self.lag_info_engine.put_lag(model, lag)
            self._put_response(model, resp, callback)

    def _put_response(self, model, response, callback):
        """Call ``callback(model, response)`` if a callback was given.

        Exceptions raised by the callback are passed to catch_exception()
        (presumably it logs them) instead of propagating.
        """
        if callback is None:
            return
        try:
            callback(model, response)
        except Exception:
            catch_exception()

    def _build_fail_response(self, err):
        """Build a synthetic ``requests.Response`` describing *err*.

        The body is *err*'s text encoded as a JSON string, or ``b'{}'`` if that
        text cannot be encoded. ``status_code`` is -1, which is not a real
        HTTP status.
        """
        try:
            # json.dumps escapes quotes, backslashes and control characters that
            # would otherwise make the body invalid JSON. ensure_ascii=False
            # keeps non-ASCII text as raw UTF-8, as before.
            content = json.dumps(str(err), ensure_ascii=False).encode()
        except Exception:
            msg = 'Requesting._build_fail_response() err, cannot encode {}'.format(err)
            logging.error(msg)
            content = b'{}'

        resp = requests.Response()
        resp._content = content
        resp.status_code = -1
        return resp

    def _check_thread_alive(self):
        """Log an error with every worker's liveness if any worker thread has died."""
        alive_list = [t.is_alive() for t in self._all_thread]
        if not all(alive_list):
            logging.error('Request thread die: {}'.format(str(alive_list)))


class SessonSelecting:
    """Empty placeholder class; it defines no behavior yet.

    NOTE(review): the name looks like a typo of "SessionSelecting". It is kept
    unchanged so that any existing references keep working.
    """



# Interval passed to Timer2 for Requesting._check_thread_alive.
# NOTE(review): the unit is whatever Timer2 expects — confirm.
check_interval = 100
# Queue waits longer than this (seconds, i.e. 30 ms) are logged as errors.
too_long_in_queue = 0.030

# Single source of request ids, shared by all Requesting instances.
req_id_iota = Iota()


if __name__ == '__main__':
    # Manual smoke test: two clients, each with its own session-id prefix,
    # take sessions from their pools in interleaved order (rq, rq2, rq, rq2).
    rq, rq2 = Requesting(), Requesting()
    a, c = rq._pool.get(), rq2._pool.get()
    b, d = rq._pool.get(), rq2._pool.get()







